# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Gist training script, adapted from huggingface's run_clm.py example.
"""

import logging
import os
import pickle
import hydra
import torch  # noqa
import torch.nn as nn
import numpy as np

from omegaconf.dictconfig import DictConfig
from transformers import (
    AutoConfig,
    AutoTokenizer,
    LlamaTokenizer,
    is_torch_tpu_available,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

from . import gist_llama, gist_t5
from .arguments import Arguments, global_setup
from .data import alpaca
from .data.utils import nested_select
from .gist_llama import DEBUG_LLAMA_CONFIG, GistLlamaForCausalLM
from .gist_t5 import GistT5ForConditionalGeneration
from .integrations import CustomWandbCallback, EvaluateFirstStepCallback
from .metrics import get_compute_metrics_fn
from .trainer_seq2seq import GistSeq2SeqTrainer
from .get_data import get_dataset

from peft import get_peft_model, LoraConfig, PromptTuningConfig, PrefixTuningConfig, PromptEncoderConfig, TaskType, PromptTuningInit

# Will error if the minimal version of Transformers is not installed. Remove at
# your own risk.
check_min_version("4.28.0.dev0")

# Fails fast at import time if the installed `datasets` package is too old.
# (The hint text is inherited from the upstream run_clm.py example this script
# was adapted from.)
require_version(
    "datasets>=1.8.0",
    "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt",
)

logger = logging.getLogger(__name__)
# Release cached, unused CUDA allocator memory. NOTE(review): this runs at
# import time, before any model is allocated, so it presumably frees nothing;
# confirm whether it is still needed.
torch.cuda.empty_cache()

# Multiplier applied to the mean pretrained token embedding when initializing
# newly added token embeddings in `main`. The choice of 7.8 is not explained
# in this file.
NORM_RATIO = 7.8

def print_trainable_parameters(model):
    """Print how many of ``model``'s parameters are trainable.

    Counts the elements of every parameter yielded by ``model.parameters()``
    and of the subset with ``requires_grad=True``, then prints a one-line
    summary. A model with no parameters reports ``0.00`` percent instead of
    raising ``ZeroDivisionError``.

    Args:
        model: A ``torch.nn.Module``.

    Returns:
        A ``(trainable_params, all_params)`` tuple of element counts.
    """
    trainable_params = 0
    all_param = 0
    for param in model.parameters():
        num = param.numel()
        all_param += num
        if param.requires_grad:
            trainable_params += num
    # Guard against parameter-free modules (would otherwise divide by zero).
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct:.2f}"
    )
    return trainable_params, all_param
    
def _apply_baseline_overrides(model_args) -> None:
    """Switch off the custom gist-token machinery for baseline runs.

    A PEFT run (``model_args.peft`` truthy) trains a regression head only when
    ``peft == "linearprobe"``; ``model_args.linear_probe`` always forces the
    regression head on and takes precedence when both are set. Either kind of
    baseline disables start/end markers, functional tokens, the auxiliary CE
    loss and inverse prompting, and enables the autoregressive attention mask.
    Mutates ``model_args`` in place.
    """
    if model_args.peft:
        model_args.regression = model_args.peft == "linearprobe"
    if model_args.linear_probe:
        model_args.regression = True
    if model_args.peft or model_args.linear_probe:
        model_args.use_start_marker = False
        model_args.use_end_marker = False
        model_args.use_functional_token = False
        model_args.add_ce_loss = False
        model_args.inverse_prompting = False
        model_args.autoregressive_attn_mask = True


def _build_peft_config(peft_method, task_name, num_token_per_prompt, model_name_or_path, is_t5):
    """Return the ``peft`` config for ``peft_method`` ("lora", "prompttuning", "prefixtuning").

    The PEFT task type follows the backbone: ``SEQ_2_SEQ_LM`` for T5-style
    encoder-decoders and ``CAUSAL_LM`` otherwise. (peft's seq2seq wrapper
    passes ``decoder_*`` kwargs to the base model, which a decoder-only LLaMA
    forward is presumably not written to accept.)

    For prompt tuning, the number of virtual tokens is ``num_token_per_prompt``
    times the number of prompts the task uses (2 for DC/QED/Descriptor, 3 for
    BA, otherwise 1).

    Raises:
        ValueError: if ``peft_method`` is not one of the supported names.
    """
    task_type = TaskType.SEQ_2_SEQ_LM if is_t5 else TaskType.CAUSAL_LM
    if peft_method == "lora":
        return LoraConfig(task_type=task_type, inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05)
    if peft_method == "prompttuning":
        if task_name in ("DC", "QED", "Descriptor"):
            num_prompt = 2
        elif task_name == "BA":
            num_prompt = 3
        else:
            num_prompt = 1
        return PromptTuningConfig(
            task_type=task_type,
            prompt_tuning_init=PromptTuningInit.RANDOM,
            num_virtual_tokens=num_token_per_prompt * num_prompt,
            tokenizer_name_or_path=model_name_or_path,
        )
    if peft_method == "prefixtuning":
        return PrefixTuningConfig(task_type=task_type, inference_mode=False, num_virtual_tokens=num_token_per_prompt)
    raise ValueError(
        f"Unsupported peft method {peft_method!r}; expected one of "
        "'lora', 'prompttuning', 'prefixtuning', 'linearprobe'."
    )


def _save_learned_weights(model, output_dir, use_scalar_encode) -> None:
    """Best-effort export of the learned embedder / regression-head weights.

    Writes ``augmented_embedder.pth`` + ``embedder_weights.npy``,
    ``lm_head_reg.pth`` + ``lm_head_reg.npy`` and (when ``use_scalar_encode``)
    ``encode_reg.npy`` into ``output_dir``. A model lacking one of these
    submodules, or an I/O failure, skips that export with a warning instead of
    aborting a run whose training already finished. Other exceptions propagate.
    """
    try:
        embedder = model.model.augmented_embedder
        torch.save(embedder.state_dict(), os.path.join(output_dir, "augmented_embedder.pth"))
        np.save(
            os.path.join(output_dir, "embedder_weights.npy"),
            embedder.embedding.weight.data.clone().detach().cpu().numpy(),
        )
    except (AttributeError, OSError) as err:
        logger.warning("Skipping augmented embedder export: %s", err)

    try:
        torch.save(model.lm_head_reg.state_dict(), os.path.join(output_dir, "lm_head_reg.pth"))
        np.save(
            os.path.join(output_dir, "lm_head_reg.npy"),
            model.lm_head_reg.weight.data.detach().cpu().numpy(),
        )
    except (AttributeError, OSError) as err:
        logger.warning("Skipping regression head export: %s", err)

    if use_scalar_encode:
        # encode_reg only exists for untied regression heads; don't crash otherwise.
        try:
            np.save(
                os.path.join(output_dir, "encode_reg.npy"),
                model.model.augmented_embedder.encode_reg.weight.data.detach().cpu().numpy(),
            )
        except (AttributeError, OSError) as err:
            logger.warning("Skipping encode_reg export: %s", err)


@hydra.main(config_path="conf", config_name="config")
def main(args: DictConfig) -> None:
    """Train a gist / custom-token model (or a PEFT / linear-probe baseline).

    Steps: resolve the Hydra config, load the model config and any previously
    learned special-token embeddings, build the task datasets, load tokenizer
    and model, initialize the new-token embeddings (or wrap the model with a
    PEFT baseline), optionally attach a regression head, freeze everything
    except the learned parameters, then run ``GistSeq2SeqTrainer``.

    While the hard-coded ``inference`` flag is set and ``do_train`` is on,
    training runs for one epoch, evaluation labels/predictions are written to
    ``labels.npy`` / ``preds.npy`` in ``model.output_dir``, and the function
    returns before the model and learned weights are saved.
    """
    args: Arguments = global_setup(args)
    # args.training.disable_tqdm = True

    # NOTE(review): hard-coded experiment switch. When True, training is capped
    # at one epoch and the run ends right after evaluation (see the training
    # section). Consider exposing this through the Hydra config.
    inference = True

    if inference:
        args.training.num_train_epochs = 1
        args.training.per_device_eval_batch_size = 16

    # Detect last checkpoint
    last_checkpoint = None
    if (
        os.path.isdir(args.training.output_dir)
        and args.training.do_train
        and not args.training.overwrite_output_dir
    ):
        last_checkpoint = get_last_checkpoint(args.training.output_dir)
        if last_checkpoint is None and len(os.listdir(args.training.output_dir)) > 0:
            existing_files = os.listdir(args.training.output_dir)
            logger.warning(
                (
                    "Output directory (%s) already exists and "
                    "is not empty. Existing files: %s. "
                    "Training anyways as these may just be output files."
                ),
                args.training.output_dir,
                str(existing_files),
            )
        elif (
            last_checkpoint is not None and args.training.resume_from_checkpoint is None
        ):
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To "
                "avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from "
                "scratch."
            )

    # Set seed before initializing model
    set_seed(args.training.seed)

    config_kwargs = {
        "cache_dir": args.model.cache_dir,
        "revision": args.model.model_revision,
        "use_auth_token": True if args.model.use_auth_token else None,
    }

    if args.model.llama_debug:
        if args.model.pretrained:
            raise RuntimeError("llama_debug requires pretrained set to False")
        config = DEBUG_LLAMA_CONFIG
    elif args.model.config_name:
        config = AutoConfig.from_pretrained(args.model.config_name, **config_kwargs)
    elif args.model.model_name_or_path:
        config = AutoConfig.from_pretrained(
            args.model.model_name_or_path, **config_kwargs
        )
    else:
        raise ValueError(
            "Unlike run_clm.py, this script does not support specifying a model type "
            "from scratch. Specify args.model.model_name_or_path and set "
            "args.pretrained = False to train from scratch instead."
        )

    is_t5 = any(t in args.model.model_name_or_path.lower() for t in ("t5", "tk"))
    is_llama = any(t in args.model.model_name_or_path.lower() for t in ("llama",))

    ###### Token dict handling ######

    task_name = args.model.task_name

    # Load previously learned special-token names and embedding weights.
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # token_dict_path at trusted directories.
    if args.model.token_dict_path is not None:
        with open(os.path.join(args.model.token_dict_path, "token_name_dict.pkl"), "rb") as f:
            token_name_dict = pickle.load(f)

        embedding_weights = torch.from_numpy(
            np.load(os.path.join(args.model.token_dict_path, "embedder_weights.npy"))
        ).float()
        num_existing_tokens = embedding_weights.shape[0]
    else:
        token_name_dict = {}
        embedding_weights = None
        num_existing_tokens = 0
        args.model.use_end_marker = False

    freeze = args.model.freeze_existing_tokens
    is_7b = args.model.model_name_or_path[-2:] == "7b"
    _apply_baseline_overrides(args.model)

    (
        train_dataset,
        eval_dataset,
        token_name_dict,
        num_new_tokens,
        update_tokens,
        start_markers,
    ) = get_dataset(
        task_name,
        num_existing_tokens,
        token_name_dict,
        args.model.num_token_per_prompt,
        args.model.use_start_marker,
        args.model.use_end_marker,
        args.model.use_functional_token,
        args.model.use_scalar_encode,
        args.model.inverse_prompting,
        args.model.regression,
        freeze,
        is_7b,
    )
    config.update({"num_new_tokens": num_new_tokens, "output_dir": args.training.output_dir, "regression_out_dim": args.model.regression_out_dim})
    print("Output dir:", args.training.output_dir, "\nToken dict:", token_name_dict, "\nTokens to learn:", update_tokens)
    print("Train data size:", len(train_dataset))

    # Persist the token mapping and resolved arguments next to the checkpoints.
    os.makedirs(args.training.output_dir, exist_ok=True)
    with open(os.path.join(args.training.output_dir, "token_name_dict.pkl"), "wb") as f:
        pickle.dump(token_name_dict, f)
    with open(os.path.join(args.training.output_dir, "arguments.pkl"), "wb") as f:
        pickle.dump(args, f)

    ############

    tokenizer_kwargs = {
        "cache_dir": args.model.cache_dir,
        "use_fast": args.model.use_fast_tokenizer,
        "revision": args.model.model_revision,
        "use_auth_token": True if args.model.use_auth_token else None,
    }

    if args.model.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            args.model.tokenizer_name, **tokenizer_kwargs
        )
    elif args.model.model_name_or_path:
        if is_llama:
            tokenizer = LlamaTokenizer.from_pretrained(
                args.model.model_name_or_path, **tokenizer_kwargs
            )
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.padding_side = "left"
        else:
            tokenizer = AutoTokenizer.from_pretrained(
                args.model.model_name_or_path, **tokenizer_kwargs
            )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported "
            "by this script."
            "You can do it from another script, save it, and load it from here, using "
            "--tokenizer_name."
        )

    if is_t5:
        model_cls = GistT5ForConditionalGeneration
    elif is_llama:
        model_cls = GistLlamaForCausalLM
    else:
        raise ValueError(f"Model type {args.model.model_name_or_path} not supported")
    if args.model.pretrained:
        model = model_cls.from_pretrained(
            args.model.model_name_or_path,
            from_tf=bool(".ckpt" in args.model.model_name_or_path),
            config=config,
            cache_dir=args.model.cache_dir,
            revision=args.model.model_revision,
            use_auth_token=True if args.model.use_auth_token else None,
        )
    else:
        model = model_cls(config)
    # Mean of the pretrained input-embedding table; seeds the new-token
    # embeddings and the regression head below.
    # NOTE(review): `model.model.embed_tokens` is the LLaMA layout; confirm the
    # T5 path exposes the same attribute.
    avg_emb = model.model.embed_tokens.weight.data.mean(0).clone()

    if args.model.peft:
        model.word_embeddings = model.model.embed_tokens
        # "linearprobe" keeps the bare model and only trains lm_head_reg.
        if args.model.peft != "linearprobe":
            peft_config = _build_peft_config(
                args.model.peft,
                task_name,
                args.model.num_token_per_prompt,
                args.model.model_name_or_path,
                is_t5,
            )
            model = get_peft_model(model, peft_config)
    else:
        ###### Embedding weight initialization ######
        embedder = model.model.augmented_embedder

        # Every new token starts as the mean embedding scaled by NORM_RATIO.
        # `repeat` also yields a valid (0, hidden) table when no tokens are new.
        embedder.embedding.weight.data = nn.Parameter(
            (avg_emb * NORM_RATIO).unsqueeze(0).repeat(num_new_tokens, 1)
        )

        if embedding_weights is not None:
            if not freeze:
                # Keep training previously learned tokens: existing rows first, then new ones.
                embedder.embedding = nn.Embedding.from_pretrained(
                    torch.cat([embedding_weights, embedder.embedding.weight.data])
                )
            else:
                # Fold previously learned tokens into the frozen base vocabulary.
                embedder.original_embedder = nn.Embedding.from_pretrained(
                    torch.cat([embedder.original_embedder.weight.data, embedding_weights])
                )

        if freeze:
            embedder.vocab_size = embedder.original_embedder.weight.data.shape[0]
            embedder.added_tokens = [embedder.vocab_size + i for i in range(num_new_tokens)]
        else:
            embedder.added_tokens = [embedder.vocab_size + i for i in range(num_existing_tokens + num_new_tokens)]

        ############

    ###### Regression input/output initialization ######
    if args.model.regression:
        if args.model.tied_weights:
            # Output head shares its weight with the learned token embeddings.
            model.lm_head_reg = nn.Linear(config.hidden_size, num_new_tokens, bias=False)
            model.lm_head_reg.weight = model.model.augmented_embedder.embedding.weight
        else:
            model.lm_head_reg = nn.Linear(config.hidden_size, args.model.regression_out_dim, bias=False)
            model.lm_head_reg.weight = nn.Parameter(torch.stack([avg_emb] * args.model.regression_out_dim))
            if args.model.use_scalar_encode:
                model.model.augmented_embedder.encode_reg = nn.Linear(args.model.regression_out_dim, config.hidden_size, bias=False)
                model.model.augmented_embedder.encode_reg.weight = nn.Parameter(avg_emb.unsqueeze(-1))

    ############
    if not args.model.peft:
        # Freeze original weight
        for _, module in model.named_children():
            for param in module.parameters():
                param.requires_grad = False

        # Set update parameters (nn.Embedding.from_pretrained freezes by default,
        # so this re-enables the learned token table explicitly).
        for param in model.model.augmented_embedder.embedding.parameters():
            param.requires_grad = True

        if args.model.regression:
            for param in model.lm_head_reg.parameters():
                param.requires_grad = True
            if args.model.use_scalar_encode:
                for param in model.model.augmented_embedder.encode_reg.parameters():
                    param.requires_grad = True

    if args.model.peft == "linearprobe":
        for _, module in model.named_children():
            for param in module.parameters():
                param.requires_grad = False
        for param in model.lm_head_reg.parameters():
            param.requires_grad = True

    # Check grad setting
    for name, module in model.named_children():
        for n, param in module.named_parameters():
            if param.requires_grad:
                print(name, n, param.shape, param)

    # Check if special token has already been added to the model (e.g. because
    # we're resuming from a checkpoint.)
    if is_t5 and len(tokenizer) == gist_t5.PRETRAINED_VOCAB_SIZE + num_new_tokens:
        assert model.shared.weight.shape[0] == gist_t5.PRETRAINED_VOCAB_SIZE + num_new_tokens
    elif is_llama and len(tokenizer) == gist_llama.PRETRAINED_VOCAB_SIZE + num_new_tokens:
        assert (
            model.model.embed_tokens.weight.shape[0]
            == gist_llama.PRETRAINED_VOCAB_SIZE + num_new_tokens
        )
        assert model.lm_head.weight.shape[0] == gist_llama.PRETRAINED_VOCAB_SIZE + num_new_tokens
    else:
        # Add tokens to tokenizer
        tokenizer.add_special_tokens({"additional_special_tokens": ["<GIST " + str(i) + ">" for i in range(num_existing_tokens + num_new_tokens)]})

    special_tokens = tokenizer.additional_special_tokens_ids

    if args.training.do_train:
        if args.data.max_train_samples is not None:
            print("Truncate training dataset to size", args.data.max_train_samples)
            max_train_samples = min(len(train_dataset), args.data.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    compute_metrics = None
    if args.training.do_eval:
        if args.data.max_eval_samples is not None:
            eval_dataset = nested_select(
                eval_dataset,
                args.data.max_eval_samples,
            )

        compute_metrics = get_compute_metrics_fn(
            gist_token=special_tokens, tokenizer=tokenizer, args=args
        )

    print_trainable_parameters(model)
    ###### Data setting ######
    # Shift marker ids past the pretrained vocabulary.
    start_markers = np.array(start_markers) + gist_llama.PRETRAINED_VOCAB_SIZE

    if is_t5:
        data_collator = alpaca.collator.DataCollatorForAlpaca(
            tokenizer,
            model=model,
            padding="longest",
            # Chosen so that <1% of examples are truncated.
            # See data/alpaca_plus/length_stats.txt for length stats.
            max_source_length=128,
            max_target_length=256,
            # Human eval examples are longer.
            max_source_length_human=384,
            max_target_length_human=384,
            label_pad_token_id=-100,
            pad_to_multiple_of=8 if args.training.fp16 else None,
            gist_condition=args.training.gist.condition,
            num_gist_tokens=args.training.gist.num_gist_tokens,
            gist_token=special_tokens,
            pad_token=tokenizer.pad_token_id,
            add_gist_token=args.training.gist.add_gist_token,
        )
    elif is_llama:
        if is_7b:
            maxl = 704
        else:
            maxl = 256 + 128
        if inference:
            maxl = 1024
        # This data collator variant does causal language modeling with left
        # padding.
        data_collator = alpaca.collator.DataCollatorForAlpacaCLM(
            tokenizer,
            icl_dataset=train_dataset if args.model.icl_method is not None else None,
            method=args.model.icl_method,
            num_demonstrations=args.model.icl_num_demonstrations,
            idx_dict=np.load(args.model.icl_idx_dict_path) if args.model.icl_idx_dict_path is not None else None,
            max_length=maxl,
            pad_token=tokenizer.pad_token_id,
            pretrained_vocab_size=gist_llama.PRETRAINED_VOCAB_SIZE,
            check_correctness=True,
            token_name_dict=token_name_dict,
            update_tokens=update_tokens,
            start_markers=start_markers,
            num_token_per_prompt=args.model.num_token_per_prompt,
            use_scalar_encode=args.model.use_scalar_encode,
            use_end_marker=args.model.use_end_marker,
            use_functional_token=args.model.use_functional_token,
            add_ce_loss=args.model.add_ce_loss,
            autoregressive_attn_mask=args.model.autoregressive_attn_mask,
        )
    else:
        assert False, "should be is_llama or is_t5"

    # Initialize trainer
    custom_callbacks = []
    if args.wandb.log:
        custom_callbacks.append(CustomWandbCallback(args))
    if args.training.evaluate_before_train:
        custom_callbacks.append(EvaluateFirstStepCallback())

    trainer = GistSeq2SeqTrainer(
        model=model,
        args=args.training,
        train_dataset=train_dataset if args.training.do_train else None,
        eval_dataset=eval_dataset if args.training.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics
        if args.training.do_eval and not is_torch_tpu_available()
        else None,
        preprocess_logits_for_metrics=None,
        callbacks=custom_callbacks,
    )

    if args.training.fp16:
        trainer.scaler = torch.cuda.amp.GradScaler(init_scale=2.**14)

    ###### Training ######
    if args.training.do_train:
        if args.training.resume_from_checkpoint is not None:
            checkpoint = args.training.resume_from_checkpoint
        else:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)

        if inference:
            # Dump evaluation labels/predictions collected by the model and
            # stop without saving the model (matches the original early exit).
            model.manual_eval = True
            trainer.evaluate()
            np.save(os.path.join(model.output_dir, "labels.npy"), model.true_val)
            np.save(os.path.join(model.output_dir, "preds.npy"), model.pred_val)
            return

        trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics

        max_train_samples = (
            args.data.max_train_samples
            if args.data.max_train_samples is not None
            else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

        _save_learned_weights(model, args.training.output_dir, args.model.use_scalar_encode)



# Script entry point. `main` is wrapped by `hydra.main`, which builds the
# DictConfig from the `conf` / `config` Hydra configuration (plus any
# command-line overrides) and passes it in, so it is called with no arguments.
if __name__ == "__main__":
    main()


